//
// Copyright (c) 2018-2020 CNRS INRIA
//

#ifndef __pinocchio_algorithm_regressor_hxx__
#define __pinocchio_algorithm_regressor_hxx__

#include "pinocchio/algorithm/check.hpp"
#include "pinocchio/algorithm/kinematics.hpp"
#include "pinocchio/spatial/skew.hpp"
#include "pinocchio/spatial/symmetric3.hpp"

namespace pinocchio
{

  namespace internal
  {
    template<
      typename Scalar,
      int Options,
      template<typename, int> class JointCollectionTpl,
      typename Matrix6xReturnType>
    void computeJointKinematicRegressorGeneric(
      const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
      const DataTpl<Scalar, Options, JointCollectionTpl> & data,
      const JointIndex joint_id,
      const ReferenceFrame rf,
      const SE3Tpl<Scalar, Options> & global_frame_placement,
      const Eigen::MatrixBase<Matrix6xReturnType> & kinematic_regressor)
    {
      assert(model.check(data) && "data is not consistent with model.");
      PINOCCHIO_CHECK_ARGUMENT_SIZE(kinematic_regressor.rows(), 6);
      PINOCCHIO_CHECK_ARGUMENT_SIZE(kinematic_regressor.cols(), 6 * (model.njoints - 1));

      typedef DataTpl<Scalar, Options, JointCollectionTpl> Data;
      typedef typename Data::SE3 SE3;

      Matrix6xReturnType & kinematic_regressor_ = kinematic_regressor.const_cast_derived();
      kinematic_regressor_.setZero();

      const SE3Tpl<Scalar, Options> & oMi = global_frame_placement;
      SE3 oMp; // placement of the frame following the jointPlacement transform
      SE3 iMp; // relative placement between the joint frame and the jointPlacement
      for (JointIndex i = joint_id; i > 0; i = model.parents[i])
      {
        const JointIndex parent_id = model.parents[i];
        oMp = data.oMi[parent_id] * model.jointPlacements[i];
        switch (rf)
        {
        case LOCAL:
          iMp = oMi.actInv(oMp);
          kinematic_regressor_.template middleCols<6>((Eigen::DenseIndex)(6 * (i - 1))) =
            iMp.toActionMatrix(); // TODO: we can avoid a copy
          break;
        case LOCAL_WORLD_ALIGNED:
          iMp.rotation() = oMp.rotation();
          iMp.translation() = oMp.translation() - oMi.translation();
          kinematic_regressor_.template middleCols<6>((Eigen::DenseIndex)(6 * (i - 1))) =
            iMp.toActionMatrix(); // TODO: we can avoid a copy
          break;
        case WORLD:
          kinematic_regressor_.template middleCols<6>((Eigen::DenseIndex)(6 * (i - 1))) =
            oMp.toActionMatrix(); // TODO: we can avoid a copy
          break;
        }
      }
    }
  } // namespace internal

  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename Matrix6xReturnType>
  void computeJointKinematicRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    const DataTpl<Scalar, Options, JointCollectionTpl> & data,
    const JointIndex joint_id,
    const ReferenceFrame rf,
    const Eigen::MatrixBase<Matrix6xReturnType> & kinematic_regressor)
  {
    PINOCCHIO_CHECK_INPUT_ARGUMENT(joint_id > 0 && (Eigen::DenseIndex)joint_id < model.njoints);
    internal::computeJointKinematicRegressorGeneric(
      model, data, joint_id, rf, data.oMi[joint_id], kinematic_regressor.const_cast_derived());
  }

  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename Matrix6xReturnType>
  void computeJointKinematicRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    const DataTpl<Scalar, Options, JointCollectionTpl> & data,
    const JointIndex joint_id,
    const ReferenceFrame rf,
    const SE3Tpl<Scalar, Options> & placement,
    const Eigen::MatrixBase<Matrix6xReturnType> & kinematic_regressor)
  {
    PINOCCHIO_CHECK_INPUT_ARGUMENT(joint_id > 0 && (Eigen::DenseIndex)joint_id < model.njoints);

    typedef DataTpl<Scalar, Options, JointCollectionTpl> Data;
    typedef typename Data::SE3 SE3;

    const SE3 global_placement = data.oMi[joint_id] * placement;

    internal::computeJointKinematicRegressorGeneric(
      model, data, joint_id, rf, global_placement, kinematic_regressor.const_cast_derived());
  }

  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename Matrix6xReturnType>
  void computeFrameKinematicRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    DataTpl<Scalar, Options, JointCollectionTpl> & data,
    const FrameIndex frame_id,
    const ReferenceFrame rf,
    const Eigen::MatrixBase<Matrix6xReturnType> & kinematic_regressor)
  {
    PINOCCHIO_CHECK_INPUT_ARGUMENT(frame_id > 0 && (Eigen::DenseIndex)frame_id < model.nframes);

    typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
    typedef typename Model::Frame Frame;

    const Frame & frame = model.frames[frame_id];
    data.oMf[frame_id] = data.oMi[frame.parentJoint] * frame.placement;

    internal::computeJointKinematicRegressorGeneric(
      model, data, frame.parentJoint, rf, data.oMf[frame_id],
      kinematic_regressor.const_cast_derived());
  }

  // Static regressor: runs forwardKinematics(model, data, q), then fills data.staticRegressor
  // (3 x 4*(njoints-1)). The 3x4 block of joint k is [t_k, R_k] / M, where (R_k, t_k) =
  // data.oMi[k] and M is the sum of model.inertias[k].mass() over k = 1..njoints-1.
  // NOTE(review): no guard against M == 0 (division by zero).
  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename ConfigVectorType>
  inline typename DataTpl<Scalar, Options, JointCollectionTpl>::Matrix3x & computeStaticRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    DataTpl<Scalar, Options, JointCollectionTpl> & data,
    const Eigen::MatrixBase<ConfigVectorType> & q)
  {
    assert(model.check(data) && "data is not consistent with model.");
    PINOCCHIO_CHECK_ARGUMENT_SIZE(q.size(), model.nq);

    typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
    typedef DataTpl<Scalar, Options, JointCollectionTpl> Data;
    typedef typename Model::JointIndex JointIndex;
    typedef typename Data::SE3 SE3;
    typedef typename Data::Matrix3x Matrix3x;
    typedef typename SizeDepType<4>::ColsReturn<Matrix3x>::Type ColsBlock;

    forwardKinematics(model, data, q.derived());

    // Total mass of the bodies (index 0, the universe, is skipped).
    Scalar total_mass = Scalar(0);
    for (JointIndex k = 1; k < (JointIndex)model.njoints; ++k)
      total_mass += model.inertias[k].mass();
    const Scalar inv_total_mass = Scalar(1) / total_mass;

    for (JointIndex k = 1; k < (JointIndex)model.njoints; ++k)
    {
      const SE3 & oMk = data.oMi[k];
      ColsBlock block =
        data.staticRegressor.template middleCols<4>((Eigen::DenseIndex)(k - 1) * 4);
      block << oMk.translation(), oMk.rotation();
      block *= inv_total_mass;
    }

    return data.staticRegressor;
  }

  namespace details
  {
    // Reference definition of bigL, kept as documentation only (not compiled).
    // bigL(w) is the 3x6 matrix such that bigL(w) * I.toDynamicParameters().tail<6>()
    // equals I.inertia() * w.
    /*
        template<typename Vector3Like>
        inline Eigen::Matrix<
          typename Vector3Like::Scalar, 3, 6, PINOCCHIO_EIGEN_PLAIN_TYPE(Vector3Like)::Options>
        bigL(const Eigen::MatrixBase<Vector3Like> & omega)
        {
          typedef typename Vector3Like::Scalar Scalar;
          enum { Options = PINOCCHIO_EIGEN_PLAIN_TYPE(Vector3Like)::Options };
          typedef Eigen::Matrix<Scalar,3,6,Options> ReturnType;

          ReturnType L;
          L <<  omega[0],  omega[1], Scalar(0),  omega[2], Scalar(0), Scalar(0),
               Scalar(0),  omega[0],  omega[1], Scalar(0),  omega[2], Scalar(0),
               Scalar(0), Scalar(0), Scalar(0),  omega[0],  omega[1],  omega[2];
          return L;
        }
    */

    // Accumulates bigL(omega) into the 3x6 matrix out (out += bigL(omega)).
    // Only the nine structurally non-zero entries of bigL are touched.
    template<typename Vector3Like, typename OutputType>
    inline void
    addBigL(const Eigen::MatrixBase<Vector3Like> & omega, const Eigen::MatrixBase<OutputType> & out)
    {
      OutputType & L = PINOCCHIO_EIGEN_CONST_CAST(OutputType, out);

      // Entries driven by omega[0].
      L(0, 0) += omega[0];
      L(1, 1) += omega[0];
      L(2, 3) += omega[0];

      // Entries driven by omega[1].
      L(0, 1) += omega[1];
      L(1, 2) += omega[1];
      L(2, 4) += omega[1];

      // Entries driven by omega[2].
      L(0, 3) += omega[2];
      L(1, 4) += omega[2];
      L(2, 5) += omega[2];
    }

    // Overwrites the 3x6 matrix out with skew(v) * bigL(v), i.e. v crossed with every column
    // of bigL(v).
    template<typename Vector3Like, typename OutputType>
    inline void
    crossBigL(const Eigen::MatrixBase<Vector3Like> & v, const Eigen::MatrixBase<OutputType> & out)
    {
      typedef typename Vector3Like::Scalar Scalar;
      OutputType & res = PINOCCHIO_EIGEN_CONST_CAST(OutputType, out);

      // Pairwise products of the components of v.
      const Scalar xx = v[0] * v[0];
      const Scalar yy = v[1] * v[1];
      const Scalar zz = v[2] * v[2];
      const Scalar yx = v[1] * v[0];
      const Scalar zx = v[2] * v[0];
      const Scalar zy = v[2] * v[1];

      res << Scalar(0), -zx, -zy, yx, yy - zz, zy, //
        zx, zy, Scalar(0), zz - xx, -yx, -zx,      //
        -yx, xx - yy, yx, -zy, zx, Scalar(0);
    }
  } // namespace details

  template<typename MotionVelocity, typename MotionAcceleration, typename OutputType>
  inline void bodyRegressor(
    const MotionDense<MotionVelocity> & v,
    const MotionDense<MotionAcceleration> & a,
    const Eigen::MatrixBase<OutputType> & regressor)
  {
    PINOCCHIO_ASSERT_MATRIX_SPECIFIC_SIZE(OutputType, regressor, 6, 10);

    typedef typename MotionVelocity::Scalar Scalar;
    enum
    {
      Options = PINOCCHIO_EIGEN_PLAIN_TYPE(typename MotionVelocity::Vector3)::Options
    };

    typedef Symmetric3Tpl<Scalar, Options> Symmetric3;
    typedef typename Symmetric3::SkewSquare SkewSquare;
    using ::pinocchio::details::addBigL;
    using ::pinocchio::details::crossBigL;

    OutputType & res = PINOCCHIO_EIGEN_CONST_CAST(OutputType, regressor);

    res.template block<3, 1>(MotionVelocity::LINEAR, 0) =
      a.linear() + v.angular().cross(v.linear());
    const Eigen::Block<OutputType, 3, 1> & acc =
      res.template block<3, 1>(MotionVelocity::LINEAR, 0);
    res.template block<3, 3>(MotionVelocity::LINEAR, 1) =
      Symmetric3(SkewSquare(v.angular())).matrix();
    addSkew(a.angular(), res.template block<3, 3>(MotionVelocity::LINEAR, 1));

    res.template block<3, 6>(MotionVelocity::LINEAR, 4).setZero();

    res.template block<3, 1>(MotionVelocity::ANGULAR, 0).setZero();
    skew(-acc, res.template block<3, 3>(MotionVelocity::ANGULAR, 1));
    // res.template block<3,6>(MotionVelocity::ANGULAR,4) = bigL(a.angular()) + cross(v.angular(),
    // bigL(v.angular()));
    crossBigL(v.angular(), res.template block<3, 6>(MotionVelocity::ANGULAR, 4));
    addBigL(a.angular(), res.template block<3, 6>(MotionVelocity::ANGULAR, 4));
  }

  // Convenience overload returning the 6x10 body regressor by value.
  template<typename MotionVelocity, typename MotionAcceleration>
  inline Eigen::Matrix<
    typename MotionVelocity::Scalar,
    6,
    10,
    PINOCCHIO_EIGEN_PLAIN_TYPE(typename MotionVelocity::Vector3)::Options>
  bodyRegressor(const MotionDense<MotionVelocity> & v, const MotionDense<MotionAcceleration> & a)
  {
    typedef Eigen::Matrix<
      typename MotionVelocity::Scalar, 6, 10,
      PINOCCHIO_EIGEN_PLAIN_TYPE(typename MotionVelocity::Vector3)::Options>
      Regressor;

    Regressor Y;
    bodyRegressor(v, a, Y);
    return Y;
  }

  // Computes the body regressor of joint joint_id from data.v[joint_id] and
  // data.a_gf[joint_id], stores it in data.bodyRegressor and returns a reference to it.
  // data.v / data.a_gf are read as-is: they must already be filled (computeJointTorqueRegressor
  // does so in its forward pass before calling this function).
  // joint_id is checked against model.njoints (via PINOCCHIO_CHECK_INPUT_ARGUMENT), consistently
  // with computeJointKinematicRegressor.
  template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
  inline typename DataTpl<Scalar, Options, JointCollectionTpl>::BodyRegressorType &
  jointBodyRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    DataTpl<Scalar, Options, JointCollectionTpl> & data,
    JointIndex joint_id)
  {
    assert(model.check(data) && "data is not consistent with model.");
    // Guard the indexing of data.v / data.a_gf below (previously out-of-range was UB).
    PINOCCHIO_CHECK_INPUT_ARGUMENT((Eigen::DenseIndex)joint_id < model.njoints);

    bodyRegressor(data.v[joint_id], data.a_gf[joint_id], data.bodyRegressor);
    return data.bodyRegressor;
  }

  // Computes the body regressor of model frame frame_id: the velocity and acceleration of the
  // frame's parent joint (data.v / data.a_gf) are brought into the frame with
  // frame.placement.actInv before calling bodyRegressor. The result is stored in
  // data.bodyRegressor and a reference to it is returned.
  // frame_id is checked against model.nframes (via PINOCCHIO_CHECK_INPUT_ARGUMENT), consistently
  // with computeFrameKinematicRegressor.
  template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
  inline typename DataTpl<Scalar, Options, JointCollectionTpl>::BodyRegressorType &
  frameBodyRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    DataTpl<Scalar, Options, JointCollectionTpl> & data,
    FrameIndex frame_id)
  {
    assert(model.check(data) && "data is not consistent with model.");
    // Guard model.frames[frame_id] below (previously out-of-range was UB).
    PINOCCHIO_CHECK_INPUT_ARGUMENT((Eigen::DenseIndex)frame_id < model.nframes);

    typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
    typedef typename Model::Frame Frame;
    typedef typename Model::JointIndex JointIndex;
    typedef typename Model::SE3 SE3;

    const Frame & frame = model.frames[frame_id];
    const JointIndex & parent = frame.parentJoint;
    const SE3 & placement = frame.placement;

    bodyRegressor(
      placement.actInv(data.v[parent]), placement.actInv(data.a_gf[parent]), data.bodyRegressor);
    return data.bodyRegressor;
  }

  // Forward pass of computeJointTorqueRegressor (joints visited in increasing index order by
  // the caller). For joint i with parent p it evaluates the joint model at (q, v) and sets
  //   data.liMi[i] = model.jointPlacements[i] * jdata.M()
  //   data.v[i]    = jdata.v() (+ liMi[i].actInv(data.v[p]) when p > 0)
  //   data.a_gf[i] = jdata.c() + data.v[i] ^ jdata.v() + S * a_i + liMi[i].actInv(data.a_gf[p])
  // The caller seeds data.a_gf[0] (with -model.gravity) before running this pass.
  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename ConfigVectorType,
    typename TangentVectorType1,
    typename TangentVectorType2>
  struct JointTorqueRegressorForwardStep
  : public fusion::JointUnaryVisitorBase<JointTorqueRegressorForwardStep<
      Scalar,
      Options,
      JointCollectionTpl,
      ConfigVectorType,
      TangentVectorType1,
      TangentVectorType2>>
  {
    typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
    typedef DataTpl<Scalar, Options, JointCollectionTpl> Data;

    typedef boost::fusion::vector<
      const Model &,
      Data &,
      const ConfigVectorType &,
      const TangentVectorType1 &,
      const TangentVectorType2 &>
      ArgsType;

    template<typename JointModel>
    static void algo(
      const JointModelBase<JointModel> & jmodel,
      JointDataBase<typename JointModel::JointDataDerived> & jdata,
      const Model & model,
      Data & data,
      const Eigen::MatrixBase<ConfigVectorType> & q,
      const Eigen::MatrixBase<TangentVectorType1> & v,
      const Eigen::MatrixBase<TangentVectorType2> & a)
    {
      typedef typename Model::JointIndex JointIndex;
      typedef typename Data::SE3 SE3;
      typedef typename Data::Motion Motion;

      const JointIndex joint_id = jmodel.id();
      const JointIndex parent_id = model.parents[joint_id];

      jmodel.calc(jdata.derived(), q.derived(), v.derived());

      SE3 & pMi = data.liMi[joint_id];
      pMi = model.jointPlacements[joint_id] * jdata.M();

      Motion & vel = data.v[joint_id];
      vel = jdata.v();
      if (parent_id > 0)
        vel += pMi.actInv(data.v[parent_id]);

      Motion & acc = data.a_gf[joint_id];
      acc = jdata.c() + (vel ^ jdata.v());
      acc += jdata.S() * jmodel.jointVelocitySelector(a);
      acc += pMi.actInv(data.a_gf[parent_id]);
    }
  };

  // Backward pass of computeJointTorqueRegressor. Visiting joint i on the path from body
  // col_idx to the root, it writes S_i^T * data.bodyRegressor into the nv_i x 10 block of
  // data.jointTorqueRegressor at rows idx_v(i) and columns [10*(col_idx-1), 10*col_idx), then
  // (unless i's parent is the universe) applies forceSet::se3Action with data.liMi[i] to
  // data.bodyRegressor in place, ready for the next ancestor.
  template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
  struct JointTorqueRegressorBackwardStep
  : public fusion::JointUnaryVisitorBase<
      JointTorqueRegressorBackwardStep<Scalar, Options, JointCollectionTpl>>
  {
    typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
    typedef DataTpl<Scalar, Options, JointCollectionTpl> Data;
    typedef typename Data::BodyRegressorType BodyRegressorType;

    typedef boost::fusion::vector<const Model &, Data &, const JointIndex &> ArgsType;

    template<typename JointModel>
    static void algo(
      const JointModelBase<JointModel> & jmodel,
      JointDataBase<typename JointModel::JointDataDerived> & jdata,
      const Model & model,
      Data & data,
      const JointIndex & col_idx)
    {
      const typename Model::JointIndex joint_id = jmodel.id();
      const Eigen::DenseIndex first_col = 10 * (Eigen::DenseIndex(col_idx) - 1);

      data.jointTorqueRegressor.block(jmodel.idx_v(), first_col, jmodel.nv(), 10) =
        jdata.S().transpose() * data.bodyRegressor;

      if (model.parents[joint_id] > 0)
        forceSet::se3Action(data.liMi[joint_id], data.bodyRegressor, data.bodyRegressor);
    }
  };

  // Joint torque regressor: fills and returns data.jointTorqueRegressor (nv x 10*(njoints-1)).
  // A forward pass computes data.liMi, data.v and data.a_gf (seeded with data.v[0] = 0 and
  // data.a_gf[0] = -model.gravity). Then, for every body (last index first), its body
  // regressor is computed and projected onto each joint on its path to the root.
  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename ConfigVectorType,
    typename TangentVectorType1,
    typename TangentVectorType2>
  inline typename DataTpl<Scalar, Options, JointCollectionTpl>::MatrixXs &
  computeJointTorqueRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    DataTpl<Scalar, Options, JointCollectionTpl> & data,
    const Eigen::MatrixBase<ConfigVectorType> & q,
    const Eigen::MatrixBase<TangentVectorType1> & v,
    const Eigen::MatrixBase<TangentVectorType2> & a)
  {
    assert(model.check(data) && "data is not consistent with model.");
    PINOCCHIO_CHECK_ARGUMENT_SIZE(q.size(), model.nq);
    PINOCCHIO_CHECK_ARGUMENT_SIZE(v.size(), model.nv);
    PINOCCHIO_CHECK_ARGUMENT_SIZE(a.size(), model.nv);

    typedef JointTorqueRegressorForwardStep<
      Scalar, Options, JointCollectionTpl, ConfigVectorType, TangentVectorType1, TangentVectorType2>
      ForwardPass;
    typedef JointTorqueRegressorBackwardStep<Scalar, Options, JointCollectionTpl> BackwardPass;

    data.jointTorqueRegressor.setZero();
    data.v[0].setZero();
    data.a_gf[0] = -model.gravity;

    // Forward sweep: kinematics of every joint.
    typename ForwardPass::ArgsType fwd_args(model, data, q.derived(), v.derived(), a.derived());
    for (JointIndex joint = 1; joint < (JointIndex)model.njoints; ++joint)
      ForwardPass::run(model.joints[joint], data.joints[joint], fwd_args);

    // Backward sweep: one body at a time, walking up its chain of ancestors.
    for (JointIndex body = (JointIndex)model.njoints - 1; body > 0; --body)
    {
      jointBodyRegressor(model, data, body);

      typename BackwardPass::ArgsType bwd_args(model, data, body);
      JointIndex ancestor = body;
      while (ancestor > 0)
      {
        BackwardPass::run(model.joints[ancestor], data.joints[ancestor], bwd_args);
        ancestor = model.parents[ancestor];
      }
    }

    return data.jointTorqueRegressor;
  }

  // Kinetic energy regressor: runs forwardKinematics(model, data, q, v) and fills
  // data.kineticEnergyRegressor (1 x 10*(njoints-1)). With v / w the linear / angular parts of
  // data.v[k], the 10-entry segment of joint k is
  //   [ 0.5|v|^2, (v x w)_x, (v x w)_y, (v x w)_z,
  //     0.5 wx^2, wx wy, 0.5 wy^2, wx wz, wy wz, 0.5 wz^2 ].
  // NOTE(review): assumes the [m, mc, Ixx, Ixy, Iyy, Ixz, Iyz, Izz] ordering of
  // InertiaTpl::toDynamicParameters, so that the segment dotted with those parameters gives the
  // body's kinetic energy — confirm.
  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename ConfigVectorType,
    typename TangentVectorType>
  const typename DataTpl<Scalar, Options, JointCollectionTpl>::RowVectorXs &
  computeKineticEnergyRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    DataTpl<Scalar, Options, JointCollectionTpl> & data,
    const Eigen::MatrixBase<ConfigVectorType> & q,
    const Eigen::MatrixBase<TangentVectorType> & v)
  {
    assert(model.check(data) && "data is not consistent with model.");
    PINOCCHIO_CHECK_ARGUMENT_SIZE(q.size(), model.nq);
    PINOCCHIO_CHECK_ARGUMENT_SIZE(v.size(), model.nv);

    // Keep the 1/2 factor in Scalar (like the Scalar(0)/Scalar(1) literals elsewhere in this
    // file): avoids promoting float models to double and stays native for non-builtin Scalars.
    const Scalar half = Scalar(0.5);

    forwardKinematics(model, data, q.derived(), v.derived());

    data.kineticEnergyRegressor.setZero();
    // iterate over each joint and compute the kinetic energy regressor
    for (JointIndex joint_id = 1; joint_id < (JointIndex)model.njoints; ++joint_id)
    {
      // linear and angular velocities (views into data.v[joint_id])
      const auto linear_vel = data.v[joint_id].linear();
      const auto angular_vel = data.v[joint_id].angular();

      const Scalar v_x = linear_vel[0], v_y = linear_vel[1], v_z = linear_vel[2],
                   w_x = angular_vel[0], w_y = angular_vel[1], w_z = angular_vel[2];

      auto joint_regressor =
        data.kineticEnergyRegressor.template segment<10>(10 * Eigen::DenseIndex(joint_id - 1));

      joint_regressor[0] = half * linear_vel.dot(linear_vel);
      // entries 1..3: components of v x w
      joint_regressor[1] = -w_y * v_z + w_z * v_y;
      joint_regressor[2] = w_x * v_z - w_z * v_x;
      joint_regressor[3] = -w_x * v_y + w_y * v_x;
      // entries 4..9: quadratic terms of w
      joint_regressor[4] = half * w_x * w_x;
      joint_regressor[5] = w_x * w_y;
      joint_regressor[6] = half * w_y * w_y;
      joint_regressor[7] = w_x * w_z;
      joint_regressor[8] = w_y * w_z;
      joint_regressor[9] = half * w_z * w_z;
    }

    return data.kineticEnergyRegressor;
  }

  // Potential energy regressor: runs forwardKinematics(model, data, q) and fills
  // data.potentialEnergyRegressor (1 x 10*(njoints-1)). With (R, t) = data.oMi[k] and
  // g = -model.gravity.linear(), the 10-entry segment of joint k is [g.t, R^T g, 0 (x6)].
  // NOTE(review): assumes the [m, mc, ...] ordering of InertiaTpl::toDynamicParameters, so the
  // segment dotted with those parameters gives m g.(t + R c) — confirm.
  template<
    typename Scalar,
    int Options,
    template<typename, int> class JointCollectionTpl,
    typename ConfigVectorType>
  const typename DataTpl<Scalar, Options, JointCollectionTpl>::RowVectorXs &
  computePotentialEnergyRegressor(
    const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
    DataTpl<Scalar, Options, JointCollectionTpl> & data,
    const Eigen::MatrixBase<ConfigVectorType> & q)
  {
    assert(model.check(data) && "data is not consistent with model.");
    PINOCCHIO_CHECK_ARGUMENT_SIZE(q.size(), model.nq);

    typedef DataTpl<Scalar, Options, JointCollectionTpl> Data;
    typedef typename Data::Vector3 Vector3;
    typedef typename Data::SE3 SE3;

    forwardKinematics(model, data, q.derived());

    data.potentialEnergyRegressor.setZero();

    // Loop-invariant: evaluated once into a plain vector instead of holding an Eigen
    // expression in `auto` and re-evaluating it for every joint.
    const Vector3 g = -model.gravity.linear();

    // iterate over each joint and compute the potential energy regressor
    for (JointIndex joint_id = 1; joint_id < (JointIndex)model.njoints; ++joint_id)
    {
      const SE3 & oMi = data.oMi[joint_id];

      auto joint_regressor =
        data.potentialEnergyRegressor.template segment<10>(10 * Eigen::DenseIndex(joint_id - 1));

      joint_regressor[0] = g.dot(oMi.translation());
      // g expressed in the joint frame.
      const Vector3 gravity_local = oMi.rotation().transpose() * g;
      joint_regressor.template segment<3>(1) = gravity_local;
    }

    return data.potentialEnergyRegressor;
  }
} // namespace pinocchio

#endif // ifndef __pinocchio_algorithm_regressor_hxx__
